{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": true
   },
   "source": [
    "# System of non-linear equations - Coupled chemical equilibria\n",
    "Author: Björn Dahlgren, Applied Physical Chemistry, KTH Royal Institute of Technology\n",
    "\n",
    "In this example we will study the equilibria between aqueous cupric ions and ammonia.\n",
    "We will use [ChemPy](https://github.com/bjodah/chempy), a Python package that collects functions and classes useful for chemistry-related problems. We will also use SymPy to manipulate and inspect the formulae we encounter."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from collections import defaultdict\n",
    "from chempy import atomic_number\n",
    "from chempy.chemistry import Species, Equilibrium\n",
    "from chempy.equilibria import EqSystem, NumSysLin, NumSysLog, NumSysSquare\n",
    "from IPython.display import Latex, display\n",
    "import matplotlib.pyplot as plt\n",
    "%matplotlib inline\n",
    "def show(s):  # convenience function\n",
    "    display(Latex('$'+s+'$'))\n",
    "import sympy; import chempy; print('SymPy: %s, ChemPy: %s' % (sympy.__version__, chempy.__version__))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's define our species by name and composition; ChemPy can parse chemical formulae:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "NH3_complexes = ['CuNH3+2', 'Cu(NH3)2+2', 'Cu(NH3)3+2', 'Cu(NH3)4+2', 'Cu(NH3)5+2']\n",
    "OH_complexes = ['Cu2(OH)2+2', 'Cu(OH)3-', 'Cu(OH)4-2']\n",
    "substances = [\n",
    "    Species.from_formula(n) for n in ['H+', 'OH-', 'NH4+', 'NH3', 'H2O', 'Cu+2'] + \n",
    "        NH3_complexes + OH_complexes + ['Cu(OH)2(s)']\n",
    "]\n",
    "substance_names = [s.name for s in substances]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's see how the species are pretty-printed:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "show(', '.join([s.latex_name for s in substances]))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's define some initial concentrations. We will consider different amounts of added ammonia in 10 mM solutions of $Cu^{2+}$:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "init_conc = defaultdict(float, {'H+': 1e-7, 'OH-': 1e-7, 'NH4+': 0,\n",
    "                                'NH3': 1.0, 'Cu+2': 1e-2, 'H2O': 55.5})"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now, let us define the equilibria. The data are from course material at Applied Physical Chemistry, KTH Royal Institute of Technology."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "H2O_c = init_conc['H2O']\n",
    "w_autop = Equilibrium({'H2O': 1}, {'H+': 1, 'OH-': 1}, 10**-14/H2O_c)\n",
    "NH4p_pr = Equilibrium({'NH4+': 1}, {'H+': 1, 'NH3': 1}, 10**-9.26)\n",
    "CuOH2_s = Equilibrium({'Cu(OH)2(s)': 1}, {'Cu+2': 1, 'OH-': 2}, 10**-18.8)\n",
    "CuOH_B3 = Equilibrium({'Cu(OH)2(s)': 1, 'OH-': 1}, {'Cu(OH)3-': 1}, 10**-3.6)\n",
    "CuOH_B4 = Equilibrium({'Cu(OH)2(s)': 1, 'OH-': 2}, {'Cu(OH)4-2': 1}, 10**-2.7)\n",
    "Cu2OH2 = Equilibrium({'Cu+2': 2, 'H2O': 2}, {'Cu2(OH)2+2': 1, 'H+': 2}, 10**-10.6 / H2O_c**2)\n",
    "CuNH3_B1 = Equilibrium({'CuNH3+2': 1}, {'Cu+2': 1, 'NH3': 1}, 10**-4.3)\n",
    "CuNH3_B2 = Equilibrium({'Cu(NH3)2+2': 1}, {'Cu+2': 1, 'NH3': 2}, 10**-7.9)\n",
    "CuNH3_B3 = Equilibrium({'Cu(NH3)3+2': 1}, {'Cu+2': 1, 'NH3': 3}, 10**-10.8)\n",
    "CuNH3_B4 = Equilibrium({'Cu(NH3)4+2': 1}, {'Cu+2': 1, 'NH3': 4}, 10**-13.0)\n",
    "CuNH3_B5 = Equilibrium({'Cu(NH3)5+2': 1}, {'Cu+2': 1, 'NH3': 5}, 10**-12.4)\n",
    "equilibria = w_autop, NH4p_pr, CuNH3_B1, CuNH3_B2, CuNH3_B3, CuNH3_B4, CuNH3_B5, Cu2OH2, CuOH2_s, CuOH_B3, CuOH_B4"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's see if we can print ``equilibria`` in a human-readable form:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "show(', '.join([s.latex_name for s in substances]))\n",
    "show('~')\n",
    "from math import log10\n",
    "for eq in equilibria:\n",
    "    ltx = eq.latex(dict(zip(substance_names, substances)))\n",
    "    show(ltx + '~'*(80-len(ltx)) + 'lgK = {0:12.5g}'.format(log10(eq.param)))  # latex table would be better..."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To keep our numerical treatment as simple as possible we will try to avoid representing\n",
    "$Cu(OH)_2(s)$ explicitly (it is present in the last three equilibria). This is because the\n",
    "system of equations changes when precipitation sets in.\n",
    "\n",
    "However, we do want to keep the last two equilibria. We therefore combine each of them with\n",
    "the dissolution equilibrium so that they involve only dissolved species:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "new_eqs = CuOH2_s - CuOH_B3, CuOH2_s - CuOH_B4\n",
    "[str(_) for _ in new_eqs]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now it's time to exclude the precipitate species and replace the last three equilibria with our two new ones:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#skip_subs, skip_eq = (4, 4) # (0, 0), (1, 1), (3, 3), (4, 4), (11, 9)\n",
    "skip_subs, skip_eq = (1, 3)\n",
    "simpl_subs = substances[:-skip_subs]\n",
    "simpl_eq = equilibria[:-skip_eq] + new_eqs\n",
    "simpl_c0 = {k.name: init_conc[k.name] for k in simpl_subs}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Using the law of mass action for the equilibria, together with the preservation of mass and charge, we can formulate a non-linear system of equations:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sympy as sp\n",
    "import numpy as np\n",
    "sp.init_printing()\n",
    "eqsys = EqSystem(simpl_eq, simpl_subs)\n",
    "x, i, Ks = sp.symarray('x', eqsys.ns), sp.symarray('i', eqsys.ns), sp.symarray('K', eqsys.nr)\n",
    "params = np.concatenate((i, Ks))\n",
    "numsys_lin = NumSysLin(eqsys, backend=sp)\n",
    "numsys_lin.f(x, params)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "It turns out that the success of numerical root finding for the above system of equations is highly sensitive to the choice of initial guess. We therefore reformulate the equations in terms of the logarithm of the concentrations:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "numsys_log = NumSysLog(eqsys, backend=sp)\n",
    "f = numsys_log.f(x, params)\n",
    "f"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We can take a peek at the Jacobian of this vector:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "sp.Matrix(1, len(f), lambda _, q: f[q]).jacobian(x)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The preservation equations of mass and charge actually contain a redundant equation, so currently our system is over-determined:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "len(f), eqsys.ns"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We could cast the preservation equations into [reduced row echelon form](https://en.wikipedia.org/wiki/Row_echelon_form) (which would remove one equation), but for now we'll leave this be and rely on the Levenberg-Marquardt algorithm to solve our problem in a least-squares sense. (Levenberg-Marquardt uses [QR-decomposition](https://en.wikipedia.org/wiki/QR_decomposition) internally, which can handle overdetermined systems.)\n",
    "\n",
    "Let's solve the equations for our initial concentrations:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "C, sol, sane = eqsys.root(simpl_c0, NumSys=NumSysLog)\n",
    "assert sol['success'] and sane\n",
    "C"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Great, let's now vary the initial concentration of $NH_3$ and plot the equilibrium concentrations of our species:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "plt.figure(figsize=(20,8))\n",
    "NH3_varied = np.logspace(-4, 0, 200)\n",
    "Cout_logC, extra, success = eqsys.roots(\n",
    "    simpl_c0, NH3_varied, 'NH3', NumSys=NumSysLog, plot_kwargs={'latex_names': True})\n",
    "# Report overall success and the total number of function and Jacobian evaluations\n",
    "print(all(success), sum(nfo['nfev'] for nfo in extra['info']), sum(nfo['njev'] for nfo in extra['info']))\n",
    "_ = plt.gca().set_ylim([1e-10, 60])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "However, the above diagram is only valid if we are below the solubility limit of our neglected $\\rm Cu(OH)_2(s)$.\n",
    "\n",
    "Let's plot the ion product $[Cu^{2+}][OH^-]^2$ together with the solubility product:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "sol_prod = Cout_logC[:, eqsys.as_substance_index('Cu+2')]*Cout_logC[:, eqsys.as_substance_index('OH-')]**2\n",
    "plt.figure(figsize=(20,6))\n",
    "plt.loglog(NH3_varied, sol_prod, label='[$%s$][$%s$]$^2$' % (eqsys.substances['Cu+2'].latex_name,\n",
    "                                                             eqsys.substances['OH-'].latex_name))\n",
    "plt.loglog(NH3_varied, Cout_logC[:, eqsys.as_substance_index('H+')], ls=':',\n",
    "           label='[$%s$]' % eqsys.substances['H+'].latex_name)\n",
    "plt.loglog([NH3_varied[0], NH3_varied[-1]], [10**-18.8, 10**-18.8], 'k--', label='$K_{sp}(Cu(OH)_2(s))$')\n",
    "plt.xlabel('[$NH_3$]')\n",
    "_ = plt.legend()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We see that for an ammonia concentration exceeding ~500-600 mM we would not precipitate $Cu(OH)_2(s)$ even though our pH is quite high (almost 12).\n",
    "\n",
    "We have solved the above system of equations for the *logarithm* of the concentrations. How big are our absolute and relative errors compared to the linear system? Let's plot them:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "fig, axes = plt.subplots(1, 2, figsize=(20, 8))\n",
    "eqsys.plot_errors(Cout_logC, simpl_c0, NH3_varied, 'NH3', axes=axes)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Not bad, so the problem is essentially solved. But suppose it is very important to know precisely where the ion product crosses the solubility limit. We can locate the intersection using the secant method:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from scipy.optimize import newton\n",
    "convergence = []\n",
    "def sol_lim(c_NH3):\n",
    "    c0 = simpl_c0.copy()\n",
    "    c0['NH3'] = c_NH3\n",
    "    C, sol, sane = eqsys.root(c0, NumSys=NumSysLog)\n",
    "    assert sol['success'] and sane\n",
    "    prod = C[eqsys.as_substance_index('Cu+2')]*C[eqsys.as_substance_index('OH-')]**2\n",
    "    discrepancy = prod/10**-18.8 - 1\n",
    "    convergence.append(discrepancy)\n",
    "    return discrepancy\n",
    "Climit_NH3 = newton(sol_lim, 0.5)\n",
    "convergence = np.array(convergence).reshape((len(convergence), 1))\n",
    "print(convergence)\n",
    "Climit_NH3"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For fun, let's see what the equation system looks like if we canonicalize it by transforming the equations for the equilibria and the equations for the preservation relations to their respective reduced row echelon forms:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "numsys_log_rref = NumSysLog(eqsys, rref_equil=True, rref_preserv=True, backend=sp)\n",
    "rf = numsys_log_rref.f(x, params)\n",
    "rf"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "So the Jacobian should be considerably more diagonally dominant now:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "sp.Matrix(1, len(rf), lambda _, q: rf[q]).jacobian(x)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "And let's see if this system converges as well:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.figure(figsize=(20,8))\n",
    "out = eqsys.roots(simpl_c0, NH3_varied, 'NH3', rref_equil=True,\n",
    "                  rref_preserv=True, NumSys=NumSysLog, plot_kwargs={'latex_names': True})\n",
    "_ = plt.gca().set_ylim([1e-10, 60])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Since version 0.2.0 of chempy there is support for automatic reformulation of the system of equations when precipitation occurs:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "full_eqsys = EqSystem(equilibria[:-2] + new_eqs, substances)\n",
    "full_numsys_log_rref = NumSysLog(full_eqsys, rref_equil=False, rref_preserv=False, precipitates=[False], backend=sp)\n",
    "full_x, full_i, full_Ks = sp.symarray('x', full_eqsys.ns), sp.symarray('i', full_eqsys.ns), sp.symarray('K', full_eqsys.nr)\n",
    "full_rf = full_numsys_log_rref.f(full_x, np.concatenate((full_i, full_Ks)))\n",
    "full_rf"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def solve_and_plot_full(NumSys, plot_kwargs={'latex_names': True}, **kwargs):\n",
    "    plt.figure(figsize=(18, 7))\n",
    "    result = full_eqsys.roots(init_conc, NH3_varied, 'NH3', NumSys=NumSys, plot_kwargs=plot_kwargs or {}, **kwargs)\n",
    "    \n",
    "    try:\n",
    "        cur_val, onset = None, None  # onset must exist even if nothing ever precipitates\n",
    "        for idx, condition in enumerate([nfo['intermediate_info'][0]['conditions'] for nfo in result[1]['info']]):\n",
    "            any_precip = condition != (False,)*len(condition)\n",
    "            if cur_val is None:\n",
    "                if any_precip:\n",
    "                    onset = idx\n",
    "            elif cur_val == any_precip:\n",
    "                pass\n",
    "            else:\n",
    "                if any_precip:\n",
    "                    onset = idx\n",
    "                else:\n",
    "                    plt.axvspan(NH3_varied[onset], NH3_varied[idx], facecolor='gray', alpha=0.1)\n",
    "                    onset = None\n",
    "            cur_val = any_precip\n",
    "        if onset is not None:\n",
    "            plt.axvspan(NH3_varied[onset], NH3_varied[-1], facecolor='gray', alpha=0.1)\n",
    "    except KeyError:\n",
    "        pass\n",
    "    return result"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "xout_log, sols_log, sane_log = solve_and_plot_full(NumSysLog, solver='scipy', tol=1e-10, conditional_maxiter=30,\n",
    "                                                   rref_equil=True, rref_preserv=True, method='lm')\n",
    "plt.gca().set_ylim([1e-10, 60])\n",
    "plt.gca().set_xscale('log')\n",
    "plt.gca().set_yscale('log')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We see that the numerical solution is not perfect (it could probably be improved by scaling the components individually). But the principle is clear: we can solve the non-linear system of equations using this method.\n",
    "\n",
    "Let's see if we can gain some more insight here:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def sum_species(x, species, substance_names, weights=None):\n",
    "    accum = np.zeros(x.shape[0])\n",
    "    if weights is None:\n",
    "        weights = [1]*len(substance_names)\n",
    "    for idx in map(substance_names.index, species):\n",
    "        accum += x[:, idx]*weights[idx]\n",
    "    return accum\n",
    "\n",
    "def plot_groups(varied, x):\n",
    "    substance_names = list(full_eqsys.substances.keys())\n",
    "    # Number of Cu atoms per formula unit (0 for species without copper)\n",
    "    weights = [s.composition.get(atomic_number('Cu'), 0) for s in full_eqsys.substances.values()]\n",
    "    amines = sum_species(x, NH3_complexes, substance_names, weights)\n",
    "    hydroxides = sum_species(x, OH_complexes, substance_names, weights)\n",
    "    free = sum_species(x, ['Cu+2'], substance_names, weights)\n",
    "    precip = sum_species(x, ['Cu(OH)2(s)'], substance_names, weights)\n",
    "    \n",
    "    tot = amines + hydroxides + free + precip\n",
    "\n",
    "    plt.figure(figsize=(13.7, 7))\n",
    "    plt.plot(varied, tot, label='tot')\n",
    "    plt.plot(varied, amines, label='amines')\n",
    "    plt.plot(varied, hydroxides, label='hydroxides')\n",
    "    plt.plot(varied, free, label='free')\n",
    "    plt.plot(varied, precip, label='precip')\n",
    "    plt.legend(loc='best')\n",
    "    plt.gca().set_xscale('log')\n",
    "\n",
    "plot_groups(NH3_varied, xout_log)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now the same calculation without any precipitates. Here we force the system not to precipitate any solids, which is applicable for short timescales, for example:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "xout_static, sols_static, sane_static = solve_and_plot_full(\n",
    "    NumSysLog, solver='scipy', tol=1e-12, # , NumSysSquare\n",
    "    neqsys_type='static_conditions',\n",
    "    rref_equil=True, rref_preserv=True,\n",
    "    precipitates=(False,), method='lm', plot_kwargs=None\n",
    ")\n",
    "plt.gca().set_xscale('log')\n",
    "plt.gca().set_yscale('log')\n",
    "plt.gca().set_ylim([1e-10, 60])\n",
    "\n",
    "plot_groups(NH3_varied, xout_static)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
